package com.core.util

import org.apache.spark.Partitioner

/**
 * Partitioner that sends each record to the partition reserved for its category id.
 *
 * The partition number of a category id is its position in `categoryIdTop10`, so the
 * number of partitions equals the length of that list. If an id occurs more than once
 * in the list, the index of its last occurrence is used.
 *
 * Two instances built from the same id list are equal, so Spark can recognise data
 * that is already partitioned this way.
 *
 * Created 2024/4/18 21:21.
 *
 * @param categoryIdTop10 category ids (as strings) that each own one partition
 * @author 小聋包
 * @version 1.0
 */
class MyPartitioner(categoryIdTop10: List[String]) extends Partitioner {
  // Give every cid a partition number; its index in the list is enough.
  private val cidAndIndex: Map[String, Int] = categoryIdTop10.zipWithIndex.toMap

  override def numPartitions: Int = categoryIdTop10.size

  /**
   * Returns the partition index for `key`.
   *
   * @param key a pair whose first element is the category id as a `Long`; the second
   *            element is ignored
   * @return the index of the category id in the list passed to the constructor
   * @throws NoSuchElementException   if the category id is not in that list
   * @throws IllegalArgumentException if `key` is not a pair starting with a `Long`
   */
  override def getPartition(key: Any): Int = key match {
    case (cid: Long, _) =>
      // The lookup table is keyed by the string form of the id.
      cidAndIndex.getOrElse(
        cid.toString,
        throw new NoSuchElementException(
          s"No partition assigned to category id $cid; known ids: ${categoryIdTop10.mkString(", ")}"
        )
      )
    case other =>
      // Fail with a descriptive message instead of an opaque MatchError.
      throw new IllegalArgumentException(
        s"MyPartitioner expects keys of the form (categoryId: Long, _), got: $other"
      )
  }

  // Structural equality: partitioners with the same id-to-partition mapping are interchangeable.
  override def equals(other: Any): Boolean = other match {
    case that: MyPartitioner =>
      that.numPartitions == numPartitions && that.cidAndIndex == cidAndIndex
    case _ => false
  }

  // Kept consistent with equals.
  override def hashCode: Int = cidAndIndex.hashCode
}
